import jax.numpy as np
import scalevi.var_dists.var_dists_base as var_dists_base
import scalevi.distributions.distributions as dists
import scalevi.distributions.scale_transforms as scale_transforms
import scalevi.utils.utils as utils
import scalevi.nn.encoders as encoders
import jax

def select_id_map(use_test):
    """Return the data-dict key that holds the ids for the requested split.

    Args:
        use_test: truthy to select the test split, falsy for the training split.

    Returns:
        ``"test_ids"`` when ``use_test`` is truthy, otherwise ``"train_ids"``.
    """
    if use_test:
        return "test_ids"
    return "train_ids"

class BranchGaussian(var_dists_base.VarBranchDist):
    """Branch family: full-rank Gaussian parent θ, full-rank Gaussian children.

    Parameters are a dict:
        'μθ' (D_par,)                 mean of θ
        'Lθ'                          packed lower-triangular (unconstrained) scale of θ
        'μw' (N_chunk, D_kid)         per-chunk child offsets
        'Aw' (N_chunk, D_kid, D_par)  per-chunk linear coupling of the child mean to θ
        'Lw'                          per-chunk packed lower-triangular (unconstrained) scales

    Raw scales are mapped to ``scale_tril`` matrices through ``scale_transform.forward``.
    """

    def __init__(
        self, N_chunk, D_par, D_kid,
        scale_transform=scale_transforms.ProximalScaleTransform(1.0)):
        # These attributes are populated before the base constructor runs.
        self.D_par, self.D_kid = D_par, D_kid
        self.scale_transform = scale_transform
        # Total latent dimension: θ plus one D_kid block per chunk.
        total_dim = D_par + N_chunk * D_kid
        super().__init__(N_chunk, total_dim)

    def initial_params(self):
        """Return the initial parameter dict.

        Means and couplings start at zero. Raw scales are the (packed) inverse
        scale transform of identity matrices.
        """
        to_tril_vec = dists.util.matrix_to_tril_vec
        inverse = self.scale_transform.inverse
        n_chunk, d_par, d_kid = self.N_chunk, self.D_par, self.D_kid
        return {
            'μθ': np.zeros(d_par),
            'Lθ': to_tril_vec(inverse(np.eye(d_par))),
            'μw': np.zeros([n_chunk, d_kid]),
            'Aw': np.zeros([n_chunk, d_kid, d_par]),
            'Lw': to_tril_vec(inverse(utils.eye_3d(d_kid, n_chunk))),
        }

    def get_params_parent(self, params):
        """Return ``loc``/``scale_tril`` keyword arguments for θ's distribution."""
        tril = dists.util.vec_to_tril_matrix(params['Lθ'])
        return {"loc": params['μθ'], "scale_tril": self.scale_transform.forward(tril)}

    def get_params_child(self, params, θ, chunk):
        """Return ``loc``/``scale_tril`` for chunk ``chunk`` conditioned on θ."""
        mean = params['μw'][chunk] + params['Aw'][chunk] @ θ
        tril = dists.util.vec_to_tril_matrix(params['Lw'][chunk])
        return {"loc": mean, "scale_tril": self.scale_transform.forward(tril)}

    def parent_dist(self, params):
        """Build the ``MultivariateNormal`` over θ."""
        dist_kwargs = self.get_params(params, "parent", None, None)
        return dists.MultivariateNormal(**dist_kwargs)

    def child_dist(self, θ, params, chunk):
        """Build the ``MultivariateNormal`` over chunk ``chunk`` given θ."""
        dist_kwargs = self.get_params(params, "child", θ, chunk)
        return dists.MultivariateNormal(**dist_kwargs)

class BranchGaussianWithSampleEval(var_dists_base.VarBranchDistWithSampleEval):
    """Full-rank branch Gaussian family built on ``CustomMultivariateNormal``.

    Parameters are a dict with keys 'μθ', 'Lθ' (θ mean and packed tril raw
    scale), 'μw', 'Aw' (per-chunk child offset and linear coupling to θ) and
    'Lw' (per-chunk packed tril raw scales). Raw scales go through
    ``scale_transform.forward``.

    Every accessor accepts ``**kwargs``. ``parent_dist``/``child_dist`` forward
    them to ``get_params``. The parameter accessors here ignore them.
    """

    def __init__(
        self, N_chunk, D_par, D_kid,
        scale_transform=scale_transforms.ProximalScaleTransform(1.0)):
        self.D_par = D_par
        self.D_kid = D_kid
        self.scale_transform = scale_transform
        # θ plus one D_kid block for every chunk.
        n_latent = D_par + N_chunk * D_kid
        super().__init__(N_chunk, n_latent)

    def initial_params(self):
        """Return zero means/couplings and raw scales that invert an identity scale."""
        to_tril_vec = dists.util.matrix_to_tril_vec
        inverse = self.scale_transform.inverse
        n_chunk, d_par, d_kid = self.N_chunk, self.D_par, self.D_kid
        return {
            'μθ': np.zeros(d_par),
            'Lθ': to_tril_vec(inverse(np.eye(d_par))),
            'μw': np.zeros([n_chunk, d_kid]),
            'Aw': np.zeros([n_chunk, d_kid, d_par]),
            'Lw': to_tril_vec(inverse(utils.eye_3d(d_kid, n_chunk))),
        }

    def get_params_parent(self, params, **kwargs):
        """Return ``loc``/``scale_tril`` for θ; extra kwargs are ignored."""
        tril = dists.util.vec_to_tril_matrix(params['Lθ'])
        return {"loc": params['μθ'], "scale_tril": self.scale_transform.forward(tril)}

    def get_params_child(self, params, θ, chunk, **kwargs):
        """Return ``loc``/``scale_tril`` for chunk ``chunk`` given θ; extra kwargs are ignored."""
        mean = params['μw'][chunk] + params['Aw'][chunk] @ θ
        tril = dists.util.vec_to_tril_matrix(params['Lw'][chunk])
        return {"loc": mean, "scale_tril": self.scale_transform.forward(tril)}

    def parent_dist(self, params, **kwargs):
        """Build the ``CustomMultivariateNormal`` over θ."""
        dist_kwargs = self.get_params(params, "parent", None, None, **kwargs)
        return dists.CustomMultivariateNormal(**dist_kwargs)

    def child_dist(self, θ, params, chunk, **kwargs):
        """Build the ``CustomMultivariateNormal`` over chunk ``chunk`` given θ."""
        dist_kwargs = self.get_params(params, "child", θ, chunk, **kwargs)
        return dists.CustomMultivariateNormal(**dist_kwargs)

class BranchBlockGaussianWithSampleEval(BranchGaussianWithSampleEval):
    """Block-diagonal branch Gaussian: θ and each chunk are independent.

    θ and every child chunk each get their own full-rank Gaussian. The child
    mean does not depend on θ, so there is no 'Aw' coupling parameter. The
    parent parameterisation and the distribution builders are inherited.
    """

    def initial_params(self):
        """Return zero means and raw tril scales that invert an identity scale (no 'Aw')."""
        to_tril_vec = dists.util.matrix_to_tril_vec
        inverse = self.scale_transform.inverse
        parent_raw_scale = to_tril_vec(inverse(np.eye(self.D_par)))
        child_raw_scale = to_tril_vec(
            inverse(utils.eye_3d(self.D_kid, self.N_chunk)))
        return {
            'μθ': np.zeros(self.D_par),
            'Lθ': parent_raw_scale,
            'μw': np.zeros([self.N_chunk, self.D_kid]),
            'Lw': child_raw_scale,
        }

    def get_params_child(self, params, θ, chunk, **kwargs):
        """Return ``loc``/``scale_tril`` for chunk ``chunk``; θ and kwargs are unused."""
        tril = dists.util.vec_to_tril_matrix(params['Lw'][chunk])
        scale_tril = self.scale_transform.forward(tril)
        return {"loc": params['μw'][chunk], "scale_tril": scale_tril}

class BranchDiagonalWithSampleEval(BranchGaussianWithSampleEval):
    """Branch family with independent diagonal Gaussians for θ and each chunk.

    Parameters are a dict:
        'μθ', 'σθ'  mean and unconstrained diagonal scale of θ
        'μw', 'σw'  per-chunk means and unconstrained diagonal scales

    Scales are mapped through ``scale_transform.forward_diag_transform``. The
    child mean does not depend on θ.
    """

    def initial_params(self):
        """Return zero means and raw scales whose inverse diag transform is one."""
        inv_diag = self.scale_transform.inverse_diag_transform
        return {
            'μθ': np.zeros(self.D_par),
            'σθ': inv_diag(np.ones(self.D_par)),
            'μw': np.zeros([self.N_chunk, self.D_kid]),
            'σw': inv_diag(np.ones((self.N_chunk, self.D_kid))),
        }

    def get_params_parent(self, params, **kwargs):
        """Return ``mu``/``sig`` keyword arguments for θ; extra kwargs are ignored."""
        sig = self.scale_transform.forward_diag_transform(params['σθ'])
        return {"mu": params['μθ'], "sig": sig}

    def get_params_child(self, params, θ, chunk, **kwargs):
        """Return ``mu``/``sig`` for chunk ``chunk``; θ and kwargs are unused."""
        sig = self.scale_transform.forward_diag_transform(params['σw'][chunk])
        return {"mu": params['μw'][chunk], "sig": sig}

    def parent_dist(self, params, **kwargs):
        """Build the ``CustomDiagonalNormal`` over θ."""
        dist_kwargs = self.get_params(params, "parent", None, None, **kwargs)
        return dists.CustomDiagonalNormal(**dist_kwargs)

    def child_dist(self, θ, params, chunk, **kwargs):
        """Build the ``CustomDiagonalNormal`` over chunk ``chunk``."""
        dist_kwargs = self.get_params(params, "child", θ, chunk, **kwargs)
        return dists.CustomDiagonalNormal(**dist_kwargs)

class AmortizedBranchGaussianWithSampleEval(BranchGaussianWithSampleEval):
    """Sample-eval branch Gaussian whose child parameters come from an encoder.

    θ keeps a free full-rank Gaussian ('μθ', 'Lθ'). For a chunk, the encoder
    receives θ, that chunk's inputs ``x`` and a ``mask``. Its output is returned
    by ``get_params_child`` and ends up (via ``get_params``) as the keyword
    arguments of ``CustomMultivariateNormal``.

    Args:
        N_chunk, D_par, D_kid: forwarded to ``BranchGaussianWithSampleEval``.
        data: dict of arrays; which keys are read depends on ``model``.
        encoder: name of the encoder class, looked up in ``scalevi.nn.encoders``.
        model: ``"RS"`` selects the ratings layout ('genome', 'movies',
            'ratings' plus a train/test id map). Any name containing
            ``"Conditional"`` appends ``data['y']`` as an extra input column.
        masked_model: stored on the instance, but the mask logic of this class
            does not consult it.
        scale_transform: shared by the encoder and the parent scale.
        **kwargs: forwarded to the encoder constructor.
    """

    def __init__(
        self, N_chunk, D_par, D_kid,
        data,
        encoder,
        model,
        masked_model=False,
        scale_transform=scale_transforms.ProximalScaleTransform(1.0),
        **kwargs):
        self.masked_model = masked_model
        self.data = data
        self.model = model
        self.conditional_model = "Conditional" in model
        encoder_cls = utils.get_attribute(encoders, encoder)
        self.encoder = encoder_cls(**kwargs, scale_transform=scale_transform)
        super().__init__(N_chunk, D_par, D_kid, scale_transform=scale_transform)

    def initial_params(self, **kwargs):
        """Return θ's initial mean/raw scale plus encoder params from ``encoder.init(**kwargs)``."""
        raw_scale = dists.util.matrix_to_tril_vec(
            self.scale_transform.inverse(np.eye(self.D_par)))
        encoder_params = self.encoder.init(**kwargs)
        return {
            'μθ': np.zeros(self.D_par),
            'Lθ': raw_scale,
            'encoder': encoder_params,
        }

    def _get_encoder_args_mask(self, chunk, **kwargs):
        """Return the ``mask`` argument for the encoder.

        For ``"RS"`` the mask marks which ids of the chunk are non-negative.
        Negative ids are presumably padding; confirm against the data loader.
        Otherwise an all-ones mask of length ``data['x'].shape[1]`` is used.
        """
        if self.model != "RS":
            return {"mask": np.ones(self.data['x'].shape[1])}
        split_ids = self.data[select_id_map(kwargs.get('use_test', False))]
        return {"mask": np.take(split_ids, chunk, 0) >= 0}

    def _get_encoder_args_x(self, chunk, **kwargs):
        """Return the ``x`` argument (the encoder's input rows) for chunk ``chunk``."""
        if self.model == "RS":
            split_ids = self.data[select_id_map(kwargs.get('use_test', False))]
            idx = np.take(split_ids, chunk, 0)
            movie_ids = np.take(self.data['movies'], idx)
            features = np.take(self.data['genome'], movie_ids, 0)
            ratings = np.take(self.data['ratings'], idx)
            # Append each rating as an extra trailing feature column.
            return {"x": np.append(features, ratings[:, None], -1)}

        inputs = self.data['x'][chunk]
        if self.conditional_model:
            inputs = np.append(inputs, self.data['y'][chunk][:, None], -1)
        return {"x": inputs}

    def _get_encoder_args(self, θ, chunk, **kwargs):
        """Collect θ, ``x`` and ``mask`` into the encoder's keyword arguments."""
        return {
            'θ': θ,
            **self._get_encoder_args_x(chunk, **kwargs),
            **self._get_encoder_args_mask(chunk, **kwargs),
        }

    def get_params_child(self, params, θ, chunk, **kwargs):
        """Run the encoder to produce chunk ``chunk``'s distribution parameters."""
        encoder_args = self._get_encoder_args(θ, chunk, **kwargs)
        return self.encoder.apply(params['encoder'], **encoder_args)

class AmortizedBranchDiagonalWithSampleEval(AmortizedBranchGaussianWithSampleEval):
    """Amortized branch family with a diagonal Gaussian parent.

    θ has a diagonal Gaussian with mean 'μθ' and unconstrained scale 'σθ'
    mapped through ``scale_transform.forward_diag_transform``. The child
    parameters still come from the encoder (inherited ``get_params_child``)
    and are passed to ``CustomDiagonalNormal``. The encoder is therefore
    presumably expected to emit ``mu``/``sig`` keys; TODO confirm against the
    encoder in use.
    """

    def initial_params(self, **kwargs):
        """Return θ's initial mean/raw diagonal scale plus encoder params."""
        raw_sig = self.scale_transform.inverse_diag_transform(np.ones(self.D_par))
        encoder_params = self.encoder.init(**kwargs)
        return {
            'μθ': np.zeros(self.D_par),
            'σθ': raw_sig,
            'encoder': encoder_params,
        }

    def get_params_parent(self, params, **kwargs):
        """Return ``mu``/``sig`` keyword arguments for θ; extra kwargs are ignored."""
        sig = self.scale_transform.forward_diag_transform(params['σθ'])
        return {"mu": params['μθ'], "sig": sig}

    def parent_dist(self, params, **kwargs):
        """Build the ``CustomDiagonalNormal`` over θ."""
        dist_kwargs = self.get_params(params, "parent", None, None, **kwargs)
        return dists.CustomDiagonalNormal(**dist_kwargs)

    def child_dist(self, θ, params, chunk, **kwargs):
        """Build the ``CustomDiagonalNormal`` over chunk ``chunk`` from encoder output."""
        dist_kwargs = self.get_params(params, "child", θ, chunk, **kwargs)
        return dists.CustomDiagonalNormal(**dist_kwargs)

class AmortizedBranchBlockGaussianWithSampleEval(AmortizedBranchGaussianWithSampleEval):
    """Alias of ``AmortizedBranchGaussianWithSampleEval`` with no overrides.

    Everything, including how child parameters depend on θ, is inherited
    unchanged. Presumably it exists so configurations can name the block
    variant explicitly. NOTE(review): confirm whether the encoder is meant to
    ignore θ here.
    """

class AmortizedBranchGaussian(BranchGaussian):
    """Branch Gaussian whose per-chunk child parameters come from an encoder.

    θ keeps a free full-rank Gaussian ('μθ', 'Lθ'). For a chunk, the encoder
    receives θ, the chunk's inputs ``x`` and, optionally, a ``mask``.

    Args:
        N_chunk, D_par, D_kid: forwarded to ``BranchGaussian``.
        data: dict of arrays; which keys are read depends on ``model``.
        encoder: name of the encoder class, looked up in ``scalevi.nn.encoders``.
        model: ``"RS"`` selects the ratings layout ('users_metadata', 'genome',
            'movies', 'ratings'). Any name containing ``"Conditional"`` appends
            ``data['y']`` as an extra input column.
        masked_model: a mask is passed to the encoder only when this is
            exactly ``True`` (identity check, not truthiness).
        scale_transform: shared by the encoder and the parent scale.
        **kwargs: forwarded to the encoder constructor.
    """

    def __init__(
        self, N_chunk, D_par, D_kid,
        data,
        encoder,
        model,
        masked_model=False,
        scale_transform=scale_transforms.ProximalScaleTransform(1.0),
        **kwargs):
        self.masked_model = masked_model
        self.data = data
        self.model = model
        self.conditional_model = "Conditional" in model
        encoder_cls = utils.get_attribute(encoders, encoder)
        self.encoder = encoder_cls(**kwargs, scale_transform=scale_transform)
        super().__init__(N_chunk, D_par, D_kid, scale_transform=scale_transform)

    def initial_params(self, **kwargs):
        """Return θ's initial mean/raw scale plus encoder params from ``encoder.init(**kwargs)``."""
        raw_scale = dists.util.matrix_to_tril_vec(
            self.scale_transform.inverse(np.eye(self.D_par)))
        encoder_params = self.encoder.init(**kwargs)
        return {
            'μθ': np.zeros(self.D_par),
            'Lθ': raw_scale,
            'encoder': encoder_params,
        }

    def _get_encoder_args_mask(self, chunk):
        """Return ``{"mask": ...}`` when masking is enabled, else an empty dict.

        For ``"RS"`` the mask marks non-negative entries of the chunk's
        'users_metadata'. Otherwise it marks positions
        ``<= data['mask'][chunk]`` along ``data['x']``'s second axis.
        """
        if self.masked_model is not True:
            return {}
        if self.model == "RS":
            user_ids = np.take(self.data['users_metadata'], chunk, 0)
            return {"mask": user_ids >= 0}
        positions = np.arange(self.data['x'].shape[1])
        return {"mask": positions <= self.data['mask'][chunk]}

    def _get_encoder_args_x(self, chunk):
        """Return the ``x`` argument (the encoder's input rows) for chunk ``chunk``."""
        if self.model == "RS":
            idx = np.take(self.data['users_metadata'], chunk, 0)
            movie_ids = np.take(self.data['movies'], idx)
            features = np.take(self.data['genome'], movie_ids, 0)
            ratings = np.take(self.data['ratings'], idx)
            # Append each rating as an extra trailing feature column.
            return {"x": np.append(features, ratings[:, None], -1)}

        inputs = self.data['x'][chunk]
        if self.conditional_model:
            inputs = np.append(inputs, self.data['y'][chunk][:, None], -1)
        return {"x": inputs}

    def _get_encoder_args(self, θ, chunk):
        """Collect θ, ``x`` and (optionally) ``mask`` into encoder keyword arguments."""
        return {
            'θ': θ,
            **self._get_encoder_args_x(chunk),
            **self._get_encoder_args_mask(chunk),
        }

    def get_params_child(self, params, θ, chunk):
        """Run the encoder to produce chunk ``chunk``'s distribution parameters."""
        encoder_args = self._get_encoder_args(θ, chunk)
        return self.encoder.apply(params['encoder'], **encoder_args)

class BranchDiagonal(BranchGaussian):
    """Branch family with diagonal Gaussians for both θ and the children.

    Unlike ``BranchGaussian``, parameters are a 5-tuple
    ``(μθ, Lθ, μw, Aw, Lw)``. ``Lθ`` and ``Lw`` hold unconstrained diagonal
    scales mapped through ``scale_transform.forward_diag_transform``. Child
    means depend linearly on θ via ``μw[i] + Aw[i] @ θ``.
    """

    def initial_params(self):
        """Return initial ``(μθ, Lθ, μw, Aw, Lw)``: zero means/couplings, unit scales."""
        inv_diag = self.scale_transform.inverse_diag_transform
        return (
            np.zeros(self.D_par),
            inv_diag(np.ones(self.D_par)),
            np.zeros([self.N_chunk, self.D_kid]),
            np.zeros([self.N_chunk, self.D_kid, self.D_par]),
            inv_diag(np.ones((self.N_chunk, self.D_kid))),
        )

    def get_params_parent(self, params):
        """Return ``(mean, diagonal scale)`` for θ."""
        scale = self.scale_transform.forward_diag_transform(params[1])
        return params[0], scale

    def get_params_child(self, params, θ, chunk):
        """Return ``(mean, diagonal scale)`` for chunk ``chunk`` given θ."""
        scale = self.scale_transform.forward_diag_transform(params[4][chunk])
        mean = params[2][chunk] + params[3][chunk] @ θ
        return mean, scale

    def parent_dist(self, params):
        """Build the ``DiagonalNormal`` over θ; its second argument is the squared scale."""
        loc, scale = self.get_params(params, "parent", None, None)
        return dists.DiagonalNormal(loc, scale ** 2)

    def child_dist(self, θ, params, chunk):
        """Build the ``DiagonalNormal`` over chunk ``chunk`` given θ."""
        loc, scale = self.get_params(params, "child", θ, chunk)
        return dists.DiagonalNormal(loc, scale ** 2)

class BranchGaussDiag(BranchDiagonal):
    """Branch family: full-rank Gaussian parent θ, diagonal Gaussian children.

    Parameters are the 5-tuple ``(μθ, Lθ, μw, Aw, Lw)``. ``Lθ`` is a packed
    lower-triangular raw scale (mapped through ``scale_transform.forward``),
    and ``Lw`` holds unconstrained per-chunk diagonal scales. The child
    accessors (``get_params_child`` / ``child_dist``) are inherited from
    ``BranchDiagonal``.
    """

    def initial_params(self):
        """Return initial ``(μθ, Lθ, μw, Aw, Lw)``.

        Means and the coupling ``Aw`` start at zero. ``Lθ`` is the packed
        inverse scale transform of the identity, and ``Lw`` is the inverse
        diagonal transform of ones.
        """
        μθ = np.zeros(self.D_par)
        Lθ = dists.util.matrix_to_tril_vec(
                        self.scale_transform.inverse(np.eye(self.D_par)))
        μw = np.zeros([self.N_chunk, self.D_kid])
        Aw = np.zeros([self.N_chunk, self.D_kid, self.D_par])
        Lw = self.scale_transform.inverse_diag_transform(
                        np.ones((self.N_chunk, self.D_kid)))
        return μθ, Lθ, μw, Aw, Lw

    def get_params_parent(self, params):
        """Return ``(mean, scale_tril)`` for θ."""
        mu = params[0]
        L = self.scale_transform.forward(
                                dists.util.vec_to_tril_matrix(params[1]))
        return mu, L

    def parent_dist(self, params):
        """Build the full-rank ``MultivariateNormal`` over θ."""
        # Pass the θ and chunk placeholders explicitly, matching every other
        # get_params(params, kind, θ, chunk) call in this module.
        mu, L = self.get_params(params, "parent", None, None)
        return dists.MultivariateNormal(mu, scale_tril=L)

class BranchDiagGauss(BranchGaussian):
    """Branch family: diagonal Gaussian parent θ, full-rank Gaussian children.

    Parameters are the 5-tuple ``(μθ, Lθ, μw, Aw, Lw)``. ``Lθ`` holds
    unconstrained diagonal scales for θ (``forward_diag_transform``).
    ``Lw[i]`` is the packed lower-triangular raw scale of chunk ``i``
    (``forward``). Child means depend linearly on θ via
    ``μw[i] + Aw[i] @ θ``.

    The tuple layout differs from the dict layout used by ``BranchGaussian``,
    whose child accessors index ``params`` by string key. The child accessors
    are therefore overridden here as well; inheriting them would fail on a
    tuple.
    """

    def initial_params(self):
        """Return initial ``(μθ, Lθ, μw, Aw, Lw)``.

        Means and the coupling ``Aw`` start at zero. ``Lθ`` is the inverse
        diagonal transform of ones, and ``Lw`` the packed inverse scale
        transform of per-chunk identity matrices.
        """
        μθ = np.zeros(self.D_par)
        Lθ = self.scale_transform.inverse_diag_transform(np.ones(self.D_par))
        μw = np.zeros([self.N_chunk, self.D_kid])
        Aw = np.zeros([self.N_chunk, self.D_kid, self.D_par])
        Lw = dists.util.matrix_to_tril_vec(
                self.scale_transform.inverse(
                    utils.eye_3d(self.D_kid, self.N_chunk)))
        return μθ, Lθ, μw, Aw, Lw

    def get_params_parent(self, params):
        """Return ``(mean, diagonal scale)`` for θ."""
        mu = params[0]
        L = self.scale_transform.forward_diag_transform(params[1])
        return mu, L

    def get_params_child(self, params, θ, chunk):
        """Return ``(mean, scale_tril)`` for chunk ``chunk`` given θ (tuple layout)."""
        mu = params[2][chunk]
        A = params[3][chunk]
        L = self.scale_transform.forward(
                dists.util.vec_to_tril_matrix(params[4][chunk]))
        return mu + A @ θ, L

    def parent_dist(self, params):
        """Build the diagonal Gaussian over θ (covariance = squared scales)."""
        mu, L = self.get_params(params, "parent", None, None)
        return dists.DiagonalNormal(mu, cov=L**2)

    def child_dist(self, θ, params, chunk):
        """Build the full-rank ``MultivariateNormal`` over chunk ``chunk`` given θ."""
        mu, L = self.get_params(params, "child", θ, chunk)
        return dists.MultivariateNormal(mu, scale_tril=L)
